# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

# from .conversation import *
from .conversation import SeparatorStyle, default_conversation

# Placeholder substituted for assistant replies when probing a chat template
# (used by infer_stop_tokens / _maybe_add_sentinel_token below).
SENTINEL_TOKEN = "<vila/sentinel>"

# Media kind -> placeholder string registered on the tokenizer in build_llm_and_tokenizer.
MEDIA_TOKENS = dict(
    image="<image>",
    video="<vila/video>",
)

# NOTE: upstream llava imports `packing`, `logger` and `infer_stop_tokens` from
# llava.* modules; this module does not. `infer_stop_tokens` is defined below.

# Ten human/gpt exchanges used as a template probe. List repetition means all
# twenty entries alias the same two dict objects.
DUMMY_CONVERSATION = 10 * [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
]


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    """Tokenize ``prompt`` and return the token ids of that single prompt.

    No special media handling happens here: media placeholders are tokenized like any
    other added token (build_llm_and_tokenizer registers MEDIA_TOKENS on the tokenizer).

    Args:
        prompt: Text to tokenize.
        tokenizer: A callable tokenizer returning an object with ``input_ids``.
        return_tensors: Forwarded to the tokenizer (e.g. ``"pt"``).

    Returns:
        With ``return_tensors`` set, the first row of the batched ``input_ids``
        (e.g. a 1-D tensor for ``"pt"``). With ``return_tensors=None`` and a string
        prompt, the plain list of ids.
    """
    input_ids = tokenizer(prompt, return_tensors=return_tensors).input_ids
    if return_tensors is None and isinstance(prompt, str):
        # Un-tensorized single-string encodings have no batch axis; indexing [0]
        # here would return just the first token id.
        return input_ids
    return input_ids[0]


def has_tokenizer(repo_id_or_path: str) -> bool:
    """Return True if a ``tokenizer_config.json`` is reachable at ``repo_id_or_path``.

    The argument is first treated as a local directory. If that directory does not
    contain the file, it is treated as a Hugging Face Hub repo id and the Hub is
    queried. Strings the Hub client rejects as repo ids (``HFValidationError``)
    yield False.
    """
    marker = "tokenizer_config.json"
    local_candidate = osp.join(repo_id_or_path, marker)
    if osp.exists(local_candidate):
        return True

    # Fall back to the Hub (this issues network requests).
    try:
        found_remote = repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, marker)
    except HFValidationError:
        found_remote = False
    return found_remote


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    """Register SENTINEL_TOKEN on ``tokenizer`` the first time this is called.

    Adds the sentinel as a special token and records it on the tokenizer as
    ``sentinel_token`` / ``sentinel_token_id``. If the tokenizer already has a
    ``sentinel_token`` attribute, nothing is done.
    """
    already_registered = hasattr(tokenizer, "sentinel_token")
    if already_registered:
        return
    # The id lookup must come after add_tokens so the new token resolves.
    tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
    tokenizer.sentinel_token = SENTINEL_TOKEN
    tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize a conversation with the legacy (non-AUTO) ``default_conversation`` template.

    Args:
        messages: Turns as ``{"from": "human"|"gpt", "value": str}`` dicts, expected to
            alternate human/gpt. A leading non-human turn is dropped. The caller's
            sequence is not modified.
        tokenizer: Tokenizer used to encode the rendered prompt.
        add_generation_prompt: Append an empty ``gpt`` turn (value None) so the
            prompt ends where the assistant should start generating.
        overrides: Optional sender -> text map; a matching turn's value is replaced.
        no_system_prompt: Clear ``conv.system`` before rendering.

    Returns:
        ``tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")``.

    Raises:
        KeyError: If a turn's sender is not ``"human"`` or ``"gpt"``.
        AssertionError: If turns do not alternate starting with human.
    """
    # Work on a copy of the template so the module-level default is left untouched
    # (assumes Conversation.copy() does not share `messages`; the list is reset below anyway).
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Private list copy: the generation-prompt append below must not leak into the
    # caller's sequence, and tuples must be accepted.
    messages = list(messages)

    # Skip the first message if it is not from human
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2], f"Conversation roles must alternate; unexpected role at turn {turn}."
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize a human/gpt conversation into token ids.

    Each turn's ``"value"`` is whitespace-stripped first; the caller's dicts are not
    modified. If ``default_conversation.sep_style`` is not ``SeparatorStyle.AUTO``, the
    work is delegated to ``tokenize_conversation_legacy``. Otherwise the turns are
    mapped to chat-template roles (human -> user, gpt -> assistant), rendered with
    ``tokenizer.apply_chat_template(tokenize=False)``, and tokenized.

    Args:
        messages: Turns as ``{"from": "human"|"gpt", "value": str}`` dicts.
        tokenizer: Tokenizer providing the chat template.
        add_generation_prompt: Forwarded to the template / legacy path.
        overrides: Optional sender -> text map; a matching turn's content is replaced.
        no_system_prompt: Prepend an explicit empty system turn (AUTO path) or clear
            the system prompt (legacy path).

    Returns:
        The tokenized prompt from ``tokenizer_image_token(..., return_tensors="pt")``.

    Raises:
        ValueError: On the AUTO path, if a sender is neither ``"human"`` nor ``"gpt"``.
    """
    # Normalize the conversation before tokenization. Build stripped copies so the
    # caller's messages (and shared constants such as DUMMY_CONVERSATION) are not mutated.
    messages = [{**message, "value": message["value"].strip()} for message in messages]

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    sender_to_role = {"human": "user", "gpt": "assistant"}
    conversation = []
    for m in messages:
        sender = m["from"]
        if sender not in sender_to_role:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
        if overrides is not None and sender in overrides:
            content = overrides[sender]
        else:
            content = m["value"]
        conversation.append({"role": sender_to_role[sender], "content": content})

    if no_system_prompt:
        # Explicit empty system turn; presumably this stops the chat template from
        # inserting its default system prompt — confirm per template.
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    """Infer the strings that end an assistant turn under the tokenizer's template.

    DUMMY_CONVERSATION is tokenized with every ``gpt`` reply replaced by the sentinel
    token. Each token that immediately follows a sentinel is decoded and collected,
    and the tokenizer's EOS token is included when it is defined.

    Side effect: registers the sentinel token on ``tokenizer`` via
    ``_maybe_add_sentinel_token``.

    Returns:
        De-duplicated stop strings in a deterministic order: EOS first, then in order
        of first appearance. (A ``set`` would give a process-dependent order.)
    """
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    # dict keys give ordered de-duplication.
    stop_tokens = {}
    if tokenizer.eos_token is not None:
        stop_tokens[tokenizer.eos_token] = None

    token_ids = template.tolist()
    for current_id, next_id in zip(token_ids, token_ids[1:]):
        if current_id == tokenizer.sentinel_token_id:
            stop_tokens[tokenizer.decode(next_id)] = None
    return list(stop_tokens)


def context_length_extension(config):
    """Enable linear RoPE scaling when the requested context exceeds the native one.

    Reads ``config.max_position_embeddings`` (native context length) and
    ``config.model_max_length`` (requested length). When both are set and the request
    is larger, this sets ``config.rope_scaling = {"type": "linear", "factor": f}``.
    Here ``f = float(ceil(requested / native))``. Otherwise ``config`` is left unchanged.

    Returns:
        The same ``config`` object, modified in place.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    # Both lengths must be known; comparing None with an int would raise TypeError.
    if orig_ctx_len and model_max_length is not None and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        # Round up to a whole factor so the scaled window always covers the request.
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the causal LM and its tokenizer, and prepare the tokenizer for VILA.

    Args:
        model_name_or_path: Local path or Hugging Face Hub id of the LLM. The tokenizer
            is looked up there first, then in its ``llm`` subdirectory.
        config: Outer model config. ``config.model_dtype`` is a string passed to
            ``eval`` to get the load dtype (presumably e.g. ``"torch.bfloat16"`` —
            confirm). An optional ``config.chat_template`` names a ``.jinja`` file
            to install on the tokenizer. ``config.hidden_size`` is overwritten with
            the LLM's hidden size.
        attn_implementation: Stored on the LLM config as ``_attn_implementation``.
        model_max_length: Stored on the LLM config (enabling RoPE scaling via
            ``context_length_extension`` when not None) and, if not None, on the
            tokenizer.
        *args, **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            ``device_map`` is dropped when DeepSpeed ZeRO-3 is enabled.

    Returns:
        ``(llm, tokenizer)``. The tokenizer also gets ``stop_tokens``,
        ``stop_token_ids``, ``media_tokens`` and ``media_token_ids`` attributes.

    Raises:
        ValueError: If no tokenizer can be located.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # SECURITY NOTE(review): `eval` runs arbitrary code taken from the config. If
    # checkpoints/configs can come from untrusted sources, replace it with an explicit
    # dtype lookup (e.g. restricted to torch dtype attributes).
    torch_dtype = eval(config.model_dtype)

    # Quantization related (restoring from an FP8 checkpoint is currently hard-disabled).
    quantization_restore_from_checkpoint = False

    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)

        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=torch_dtype, *args, **kwargs
        )
    else:
        if is_deepspeed_zero3_enabled():
            # NOTE: found by wei, need to pop out device_map when using zero3.
            # Default to None so callers that never passed device_map do not hit a KeyError.
            kwargs.pop("device_map", None)
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=torch_dtype, *args, **kwargs
        )
    # NOTE(ligeng): upstream calls `packing.patch(llm)` here; it is disabled because it
    # is unclear whether it affects training.

    # Locate the tokenizer: the model path itself, else its "llm" subdirectory.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load chat template if specified: bundled chat_templates/ dir first, then the
    # parent directory of model_name_or_path.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath, encoding="utf-8") as fd:
            chat_template = fd.read()
        # Collapse the template to one line by removing every 4-space run and newline.
        # NOTE(review): this also affects such runs inside template string literals.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
